
from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import datasets, transforms
import numpy as np
import copy
from torch.utils.data import ConcatDataset
from sklearn.model_selection import train_test_split
import os
import torch
from torchvision import datasets, transforms
from utils.DataLoader import DataLoader
from datasets import load_dataset, concatenate_datasets, DatasetDict, ClassLabel, load_from_disk, Dataset
import math
# from torch.utils.data import DataLoader
from PIL import Image



class CustomImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that keeps only samples whose class index is ``< num_classes``.

    Class indices are the ones assigned by ``ImageFolder`` itself. ``classes`` and
    ``class_to_idx`` are left unchanged, so they still list every class found
    under ``root``. ``samples``, ``imgs`` and ``targets`` reflect the filtered subset.
    """

    def __init__(self, root, transform=None, num_classes=20):
        super(CustomImageFolder, self).__init__(root, transform=transform)

        # Keep only samples of the first `num_classes` class indices. Class indices
        # are non-negative, so `cls < num_classes` is equivalent to
        # `cls in range(num_classes)`, without an O(num_classes) scan per sample.
        self.samples = [(path, cls) for path, cls in self.samples if cls < num_classes]
        # ImageFolder exposes `imgs` as an alias of `samples`; re-point it so it
        # does not keep referring to the unfiltered list.
        self.imgs = self.samples
        self.targets = [cls for _, cls in self.samples]


class SubsetDataset(torch.utils.data.Dataset):
    """View of ``original_dataset`` restricted to ``indices``, with optional backdoored items.

    Once :meth:`add_trigger` has run, the first ``len(self.bd_data)`` positions
    are served from the backdoored copies in ``bd_data``. All other positions
    read ``original_dataset[indices[i]]``.
    """

    def __init__(self, original_dataset, indices):
        self.dataset = original_dataset
        self.indices = indices
        self.bd_data = []     # backdoored samples for positions 0 .. len(bd_data) - 1
        self.bd_flag = False  # True once add_trigger has populated bd_data

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        """Return ``(data, target)`` for position ``index`` of this subset.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        # Normalise negative indices first. Without this, `ds[-1]` would be
        # answered from bd_data[-1] instead of the last item of the subset.
        if index < 0:
            index += len(self)
            if index < 0:
                raise IndexError('SubsetDataset index out of range')
        # Positions covered by add_trigger are served from the backdoored copies.
        if self.bd_flag and index < len(self.bd_data):
            return self.bd_data[index]
        data, target = self.dataset[self.indices[index]]
        return data, target

    def add_trigger(self, bd_maker, attack_portion=0.8):
        """Backdoor the first ``int(len(self) * attack_portion)`` samples.

        Each sample is passed through ``bd_maker.add_backdoor``. ``bd_data`` is
        rebuilt from scratch, so a repeated call replaces the previous
        backdoored samples instead of appending to them.
        """
        num_poisoned = int(len(self) * attack_portion)
        self.bd_data = [bd_maker.add_backdoor(self.dataset[self.indices[i]])
                        for i in range(num_poisoned)]
        self.bd_flag = True

class DataLoader_pacs(DataLoader):
    """Federated data pool built from the PACS dataset (one pool entry per domain).

    Pool entry ``i`` corresponds to domain ``self.domains[i]``.

    * With ``split_num == 2``, each domain's samples are split by label. The
      "local" part has labels ``< label_split`` and the "pool" part has labels
      ``>= label_split``. Entry ``i`` gets the local part of domain ``i`` plus
      the pool part of domain ``(i + 1) % pool_size``.
    * With ``split_num == 1``, entry ``i`` gets all data of domain ``i``.

    The built object is saved with ``np.save`` to
    ``<utils.pool_folder_path>/<name>.npy``. Later runs reload it from there
    unless ``recreate`` is true.
    """

    def __init__(self,
                 batch_size=100,
                 split_num=2,
                 class_num=7,
                 input_require_shape=None,
                 pool_size=None,
                 params=None,
                 recreate=False,
                 *args,
                 **kwargs):
        """Load the cached PACS pool or build it, then attach the image transform.

        Args:
            batch_size: forwarded to the base ``DataLoader``; part of the cache name.
            split_num: 1 (whole domains) or 2 (label-split local/pool parts).
            class_num: only labels ``< class_num`` are kept.
            input_require_shape: forwarded to the base ``DataLoader``.
            pool_size: forwarded to the base ``DataLoader``. When the pool is
                built it is overwritten with the number of domains.
            params: optional dict. If given, its ``'batch_size'``,
                ``'split_num'`` and ``'class_num'`` override the arguments.
            recreate: rebuild the pool even if a cached ``.npy`` exists.

        Raises:
            ValueError: if the pool has to be built and ``split_num`` is not 1 or 2.
        """
        if params is not None:
            batch_size = params['batch_size']
            split_num = params['split_num']  # TODO: check split num
            class_num = params['class_num']
        name = f'PACS_pool_4_split_{split_num}_class_{class_num}_batchsize_{batch_size}'
        nickname = None
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        save_path = os.path.join(utils.pool_folder_path, f'{name}.npy')

        transform = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        def trans(examples):
            # Used via set_transform below: converts each image in the batch to
            # RGB and runs it through `transform` when the data is accessed.
            examples['image'] = [transform(image.convert("RGB")) for image in examples['image']]
            return examples

        if os.path.exists(save_path) and not recreate:
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load pool
            # files written by this project.
            data_loader = np.load(save_path, allow_pickle=True).item()
            for attr, value in data_loader.__dict__.items():
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
        else:
            self._build_pacs_pool(name, split_num, class_num)
            # Saved before set_transform below. Presumably this keeps the local
            # `trans` closure, which the standard pickler cannot serialise, out
            # of the saved state.
            np.save(save_path, self)

        for pool in self.data_pool:
            pool['local_training_data'].set_transform(trans)
            pool['local_test_data'].set_transform(trans)

    def _build_pacs_pool(self, name, split_num, class_num):
        """Populate ``self.data_pool`` and the bookkeeping attributes from the raw PACS data."""
        # Validate up front, before any expensive dataset work happens.
        if split_num not in (1, 2):
            raise ValueError('Split num must be 2 or 1.')

        self.name = name
        self.domains = ['art_painting', 'cartoon', 'photo', 'sketch']
        self.classes = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
        self.pool_size = len(self.domains)
        self.input_data_shape = [3, 224, 224]
        # Use the resolved argument: `params` may be None here.
        self.target_class_num = class_num
        self.total_training_number = 0
        self.total_test_number = 0
        self.server_data = {}
        self.server_data_number = {}
        self.statistic = {}

        # TODO: re-split the data so that each domain's own data forms the bulk
        # of its client and only the smaller share goes into the shared pool.
        dataset = self._load_pacs_splits()
        train_dataset, test_dataset = dataset['train'], dataset['test']
        self.domain_dict = train_dataset.features['domain']
        self.label_dict = train_dataset.features['label']

        local_datas, pool_datas = self._split_pacs_by_domain(train_dataset, test_dataset, split_num)
        self.data_pool = self._create_pacs_data_pool(local_datas, pool_datas)

    def _load_pacs_splits(self):
        """Load the class-filtered PACS train/test split from disk, creating it on first use."""
        split_dir = os.path.join(utils.data_folder_path, 'PACS', f'{self.target_class_num}')
        try:
            return load_from_disk(split_dir)
        except FileNotFoundError:
            print(f'PACS target class num {self.target_class_num} is not found, now begin to create')

        # Bind to a plain int so the multi-process filter does not pickle `self`.
        target_class_num = self.target_class_num
        dataset = load_dataset(os.path.join(utils.data_folder_path, 'PACS', 'pacs'), split='train')
        dataset = dataset.filter(lambda example: example['label'] < target_class_num, num_proc=16)
        new_features = dataset.features.copy()
        new_features["domain"] = ClassLabel(num_classes=len(self.domains), names=self.domains)
        dataset = dataset.cast(new_features)
        dataset = dataset.train_test_split(test_size=0.3)  # 70% train / 30% test
        dataset.save_to_disk(split_dir)
        return dataset

    def _split_pacs_by_domain(self, train_dataset, test_dataset, split_num):
        """Return per-domain ``[train, test]`` lists for the local parts and the pool parts."""
        # NOTE(review): label_split is computed from all 7 PACS classes, not from
        # target_class_num. With class_num <= 5, every kept label is < label_split,
        # so the pool parts are empty. Confirm this is intended.
        label_split = math.ceil(len(self.classes) * 0.7)
        local_datas, pool_datas = [], []

        for domain in self.domains:
            # Plain int captured by the filter lambdas (cheap to send to workers).
            domain_id = self.domain_dict.str2int(domain)
            d_train_data = train_dataset.filter(lambda example: example["domain"] == domain_id, num_proc=16)
            d_test_data = test_dataset.filter(lambda example: example["domain"] == domain_id, num_proc=16)

            print(domain, 'train', len(d_train_data))
            print(domain, 'test', len(d_test_data))

            # Balance domains: cap each at 2000 train and 1000 test samples.
            if len(d_train_data) > 2000:
                d_train_data = d_train_data.shuffle().select(range(2000))
            if len(d_test_data) > 1000:
                d_test_data = d_test_data.shuffle().select(range(1000))

            if split_num == 2:
                local_train_data = d_train_data.filter(lambda example: example['label'] < label_split, num_proc=16)
                local_test_data = d_test_data.filter(lambda example: example['label'] < label_split, num_proc=16)
                pool_train_data = d_train_data.filter(lambda example: example['label'] >= label_split, num_proc=16)
                pool_test_data = d_test_data.filter(lambda example: example['label'] >= label_split, num_proc=16)
            else:  # split_num == 1 (validated in _build_pacs_pool)
                local_train_data, local_test_data = d_train_data, d_test_data
                # Empty pool parts; column names must match the real data.
                pool_train_data = Dataset.from_dict({"image": [], "domain": [], "label": []})
                pool_test_data = Dataset.from_dict({"image": [], "domain": [], "label": []})

            local_datas.append([local_train_data, local_test_data])
            pool_datas.append([pool_train_data, pool_test_data])

        return local_datas, pool_datas

    def _create_pacs_data_pool(self, local_datas, pool_datas):
        """Build one pool entry per domain: own local part + next domain's pool part."""
        data_pool = []
        for pool_idx in range(self.pool_size):
            neighbour_idx = (pool_idx + 1) % self.pool_size
            local_train, local_test = local_datas[pool_idx]
            neighbour_train, neighbour_test = pool_datas[neighbour_idx]

            train_data = concatenate_datasets([local_train, neighbour_train])
            test_data = concatenate_datasets([local_test, neighbour_test])

            # Domain names present in each split (each column is read once).
            train_domains = self.domain_dict.int2str(set(train_data['domain']))
            test_domains = self.domain_dict.int2str(set(test_data['domain']))
            print(train_domains)
            print(test_domains)
            print(len(train_data))
            print(len(test_data))

            data_pool.append({
                'local_training_data': train_data,
                'local_training_domain': train_domains,
                'local_test_data': test_data,
                'local_test_domain': test_domains,
                'local_training_number': len(train_data),
                'local_test_number': len(test_data),
                # domain_name: [train_num, test_num]
                'local_statistic': {
                    self.domains[pool_idx]: [len(local_train), len(local_test)],
                    self.domains[neighbour_idx]: [len(neighbour_train), len(neighbour_test)],
                },
            })
        return data_pool

    def allocate(self, client_list):
        """Hand pool entry ``i`` to the ``i``-th client via ``client.update_data``.

        Raises:
            IndexError: if there are more clients than pool entries.
        """
        for pool_idx, client in enumerate(client_list):
            item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               item['local_training_data'],
                               item['local_training_number'],
                               item['local_test_data'],
                               item['local_test_number'],
                               item['local_statistic'])

    #         def create_data_pool(data_pool):
    #             for pool_idx in range(self.pool_size):
    #                 # local_datas[idx] + pool_datas[(idx+1)%self.pool_size]
    #                 local_training_data = DatasetDict({1
    #                     'train': concatenate_datasets(
    #                         # [local_datas[pool_idx]['train'], pool_datas[(pool_idx + len(self.domains)-1) % self.pool_size]['train']]),
    #                         [local_datas[pool_idx]['train'], pool_datas[(pool_idx + 1) % self.pool_size]['train']]),
    #                     'test': concatenate_datasets(
    #                         # [local_datas[pool_idx]['test'], pool_datas[(pool_idx + len(self.domains)-1) % self.pool_size]['test']])
    #                         [local_datas[pool_idx]['test'], pool_datas[(pool_idx + 1) % self.pool_size]['test']])
    #
    #                 })
    #                 # local_training_data.save_to_disk(os.path.join(file_path, f'local_{pool_idx}_data'))
    #
    #                 data_pool[pool_idx]['local_training_data'] = local_training_data['train']
    #                 print(self.domain_dict.int2str(set(local_training_data['train']['domain'])))
    #                 data_pool[pool_idx]['local_training_domain'] = self.domain_dict.int2str(set(local_training_data['train']['domain']))
    #                 data_pool[pool_idx]['local_test_data'] = local_training_data['test']
    #                 data_pool[pool_idx]['local_test_domain'] = self.domain_dict.int2str(set(local_training_data['test']['domain']))
    #                 print(self.domain_dict.int2str(set(local_training_data['test']['domain'])))
    #                 data_pool[pool_idx]['local_training_number'] = len(local_training_data['train'])
    #                 print(len(local_training_data['train']))
    #                 data_pool[pool_idx]['local_test_number'] = len(local_training_data['test'])
    #                 print(len(local_training_data['test']))
    #                 # domain_name: [train_num, test_num]
    #                 data_pool[pool_idx]['local_statistic'] = {self.domains[pool_idx]: [len(local_datas[pool_idx]['train']), len(local_datas[pool_idx]['test'])],
    #                                                           self.domains[(pool_idx + 1) % self.pool_size]: [len(pool_datas[(pool_idx + 1) % self.pool_size]['train']), len(pool_datas[(pool_idx + 1) % self.pool_size]['test'])]}
    #
    #
    #                                                           # self.domains[(pool_idx + len(self.domains)-1) % self.pool_size]: [len(pool_datas[(pool_idx + len(self.domains)-1) % self.pool_size]['train']),
    #                                                           #                                                                   len(pool_datas[(pool_idx + len(self.domains)-1) % self.pool_size]['test'])]}
    #
    #
    #                                         # self.statistic[f'Client {pool_idx}'] = {'Main Domain': self.domains[pool_idx], 'Main vols': [len(all_train_data[pool_idx]), len(all_test_data[pool_idx])],
    #                 #                             'Sub Domain': self.domains[(pool_idx+self.pool_size-1) % self.pool_size], 'Sub vols': [len(all_train_pool[(pool_idx+1) % self.pool_size]), len(all_test_pool[(pool_idx+1) % self.pool_size])]}
    #
    #
    #         data_pool = [{} for _ in range(self.pool_size)]
    #         # local_training local_test
    #         create_data_pool(data_pool)
    #         self.data_pool = data_pool
    #
    #         # print(self.statistic)
    #         np.save(save_path, self)
    #     # convert PIL to torch.tensor
    #     for pool in self.data_pool:
    #         pool['local_training_data'].set_transform(trans)
    #         pool['local_test_data'].set_transform(trans)
    #
    # def allocate(self, client_list):
    #     choose_data_pool_item_indices = list(range(self.pool_size))
    #     for idx, client in enumerate(client_list):
    #         data_pool_item = self.data_pool[choose_data_pool_item_indices[idx]]
    #         client.update_data(choose_data_pool_item_indices[idx],
    #                            data_pool_item['local_training_data'],
    #                            data_pool_item['local_training_number'],
    #                            data_pool_item['local_test_data'],
    #                            data_pool_item['local_test_number'],
    #                            data_pool_item['local_statistic'])


            # train_data, test_data, server_data = domain_data[train_indices], domain_data[test_indices], domain_data[server_indices]
            # train_targets, test_targets, server_targets = domain_targets[train_indices], domain_targets[test_indices], domain_targets[server_indices]
            #
            # batch_train_data = DataLoader.separate_list(train_data, 256)
            # batch_test_data = DataLoader.separate_list(test_data, 256)
            #
            # batch_train_targets = DataLoader.separate_list(train_targets, 256)
            # batch_test_targets = DataLoader.separate_list(test_targets, 256)

            # 注意用Subset构造会失去部分dataset类的功能
            # train_dataset = torch.utils.data.Subset(dataset, train_indices)
            # test_dataset = torch.utils.data.Subset(dataset, test_indices)
            # server_dataset = torch.utils.data.Subset(dataset, server_indices)

            # local_train_dataloader.append(torch.utils.data.DataLoader(train_dataset, batch_size=256, num_workers=8, shuffle=True, pin_memory=True))
            # local_test_dataloader.append(torch.utils.data.DataLoader(test_dataset, batch_size=256, num_workers=8, shuffle=False, pin_memory=True))
